/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file data_copy_wrapper.h
 * \brief Data-copy wrapper used by the matmul copy-cube-in stage to stage A/B input
 *        tiles (ND/NZ/VECTOR formats) into the cube (L1) buffer.
 */

#ifndef IMPL_MATMUL_MODULES_STAGE_COPY_CUBE_IN_DATA_COPY_WRAPPER_H
#define IMPL_MATMUL_MODULES_STAGE_COPY_CUBE_IN_DATA_COPY_WRAPPER_H

#include "../../matmul_module.h"
#include "../../matmul_param.h"
#include "../../param/matmul_var.h"
#include "copy_cube_in_utils.h"
#include "copy_cube_in_params.h"

namespace matmul {

using namespace AscendC;

template<typename IMPL, const auto& MM_CFG, class INPUT_TYPE>
class DataCopyWrapper {
    using TransT = typename INPUT_TYPE::TRANS_T;
    using SrcT = typename INPUT_TYPE::T;

    MATMUL_USE_MODULE_ON(MatmulVar, INPUT_TYPE::TAG);
    MATMUL_USE_MODULE_ON(MatmulShapeInfo, INPUT_TYPE::TAG);
    MATMUL_USE_MODULE_ON(CopyCubeInParams, INPUT_TYPE::TAG);
    MATMUL_USE_MODULE_ON(MatmulTensorInfo, INPUT_TYPE::TAG);

    // Tiling-derived ("static") tile height for matrix A (SFINAE-enabled for TAG == A only).
    // - Batched layout and the batch is not larger than L1: full single-core K (transposed) or M.
    // - MDL / special-MDL templates: stepKa * baseK (transposed) or stepM * baseM.
    // - Otherwise: the base block height reported by MatmulShapeInfo.
    template <bool IS_TRANS = false, typename INPUT_TYPE_ALIAS = INPUT_TYPE>
    __aicore__ constexpr enable_if_t<INPUT_TYPE_ALIAS::TAG == InputTypeTag::A, int32_t> GetStaticTileHeight() const
    {
        if constexpr ((INPUT_TYPE_ALIAS::layout != LayoutMode::NONE) &&
            (ToMatmulConfig(MM_CFG).batchMode != BatchMode::SINGLE_LARGE_THAN_L1)) {
            if constexpr (IS_TRANS) {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetSingleCoreK();
            } else {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetSingleCoreM();
            }
        } else if constexpr (DoMatmulMDL(MM_CFG) || DoMatmulSpecialMDL(MM_CFG)) {
            if constexpr (IS_TRANS) {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetStepKa() * MATMUL_CONST_PARAM_VAR.tiling_.GetBaseK();
            } else {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetStepM() * MATMUL_CONST_PARAM_VAR.tiling_.GetBaseM();
            }
        } else {
            return MATMUL_MODULE(MatmulShapeInfo)->template GetBaseHeight<IS_TRANS>();
        }
    }

    // Tiling-derived tile width for matrix A; mirror of GetStaticTileHeight with the
    // M/K roles swapped (height and width exchange when the input is transposed).
    template <bool IS_TRANS = false, typename INPUT_TYPE_ALIAS = INPUT_TYPE>
    __aicore__ constexpr enable_if_t<INPUT_TYPE_ALIAS::TAG == InputTypeTag::A, int32_t> GetStaticTileWidth() const
    {
        if constexpr ((INPUT_TYPE_ALIAS::layout != LayoutMode::NONE) &&
            (ToMatmulConfig(MM_CFG).batchMode != BatchMode::SINGLE_LARGE_THAN_L1)) {
            if constexpr (IS_TRANS) {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetSingleCoreM();
            } else {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetSingleCoreK();
            }
        } else if constexpr (DoMatmulMDL(MM_CFG) || DoMatmulSpecialMDL(MM_CFG)) {
            if constexpr (IS_TRANS) {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetStepM() * MATMUL_CONST_PARAM_VAR.tiling_.GetBaseM();
            } else {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetStepKa() * MATMUL_CONST_PARAM_VAR.tiling_.GetBaseK();
            }
        } else {
            return MATMUL_MODULE(MatmulShapeInfo)->template GetBaseWidth<IS_TRANS>();
        }
    }

    // Tiling-derived tile height for matrix B (SFINAE-enabled for TAG == B only);
    // uses the N/K tiling dimensions (stepN/baseN, stepKb/baseK).
    template <bool IS_TRANS = false, typename INPUT_TYPE_ALIAS = INPUT_TYPE>
    __aicore__ inline enable_if_t<INPUT_TYPE_ALIAS::TAG == InputTypeTag::B, int32_t> GetStaticTileHeight() const
    {
        if constexpr ((INPUT_TYPE_ALIAS::layout != LayoutMode::NONE) &&
            (ToMatmulConfig(MM_CFG).batchMode != BatchMode::SINGLE_LARGE_THAN_L1)) {
            if constexpr (IS_TRANS) {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetSingleCoreN();
            } else {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetSingleCoreK();
            }
        } else if constexpr (DoMatmulMDL(MM_CFG) || DoMatmulSpecialMDL(MM_CFG)) {
            if constexpr (IS_TRANS) {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetStepN() * MATMUL_CONST_PARAM_VAR.tiling_.GetBaseN();
            } else {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetStepKb() * MATMUL_CONST_PARAM_VAR.tiling_.GetBaseK();
            }
        } else {
            return MATMUL_MODULE(MatmulShapeInfo)->template GetBaseHeight<IS_TRANS>();
        }
    }

    // Tiling-derived tile width for matrix B; mirror of the B-height variant with
    // the K/N roles swapped.
    template <bool IS_TRANS = false, typename INPUT_TYPE_ALIAS = INPUT_TYPE>
    __aicore__ inline enable_if_t<INPUT_TYPE_ALIAS::TAG == InputTypeTag::B, int32_t> GetStaticTileWidth() const
    {
        if constexpr ((INPUT_TYPE_ALIAS::layout != LayoutMode::NONE) &&
            (ToMatmulConfig(MM_CFG).batchMode != BatchMode::SINGLE_LARGE_THAN_L1)) {
            if constexpr (IS_TRANS) {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetSingleCoreK();
            } else {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetSingleCoreN();
            }
        } else if constexpr (DoMatmulMDL(MM_CFG) || DoMatmulSpecialMDL(MM_CFG)) {
            if constexpr (IS_TRANS) {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetStepKb() * MATMUL_CONST_PARAM_VAR.tiling_.GetBaseK();
            } else {
                return MATMUL_CONST_PARAM_VAR.tiling_.GetStepN() * MATMUL_CONST_PARAM_VAR.tiling_.GetBaseN();
            }
        } else {
            return MATMUL_MODULE(MatmulShapeInfo)->template GetBaseWidth<IS_TRANS>();
        }
    }

public:
    __aicore__ inline DataCopyWrapper() = default;
    __aicore__ inline ~DataCopyWrapper() = default;

    /**
     * Copy one (curRow, curCol) tile of the input matrix into the cube buffer `dst`.
     * Dispatch order:
     *   1. user callback (CopyA1Ptr / CopyB1Ptr) when registered;
     *   2. otherwise UB or GM source, transposed or not, via CopyTileToCubeFromUB/FromGM.
     * For transposed inputs, row/col and height/width arguments are swapped before the copy.
     * int4 ND inputs pack two elements per byte, so widths are divided by INT4_TWO.
     *
     * NOTE: the `#ifdef ASCENDC_CPU_DEBUG` arms below intentionally leave their last
     * `if`/`else if` block unclosed; the shared `} else {` after `#endif` closes
     * whichever arm was compiled. CPU debug uses runtime `if` so the callback pointer
     * can be inspected at runtime; device builds use `if constexpr`.
     */
    template <bool IS_INTRA_BLOCK = false>
    __aicore__ inline void CopyTileToCube(const LocalTensor<TransT>& dst, int32_t curRow, int32_t curCol,
        int32_t tileHeight, int32_t tileWidth)
    {
#ifdef ASCENDC_CPU_DEBUG
        if (INPUT_TYPE::TAG == InputTypeTag::A && IMPL::CallBack::CopyA1Ptr) {
            LocalTensor<int8_t> tmpDst = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyA1Ptr)(tmpDst,
                reinterpret_cast<__gm__ void *>(MATMUL_MODULE(MatmulTensorInfo)->template GetGlobalAddr<IS_INTRA_BLOCK>()),
                curRow, curCol, tileHeight, tileWidth, MATMUL_MODULE(MatmulTensorInfo)->GetUserDefineInfo(),
                MATMUL_MODULE(MatmulTensorInfo)->GetSelfDefineData());
        } else if (INPUT_TYPE::TAG == InputTypeTag::B && IMPL::CallBack::CopyB1Ptr) {
            LocalTensor<int8_t> tmpDst = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyB1Ptr)(tmpDst,
                reinterpret_cast<__gm__ void *>(MATMUL_MODULE(MatmulTensorInfo)->template GetGlobalAddr<IS_INTRA_BLOCK>()),
                curRow, curCol, tileHeight, tileWidth, MATMUL_MODULE(MatmulTensorInfo)->GetUserDefineInfo(),
                MATMUL_MODULE(MatmulTensorInfo)->GetSelfDefineData());
#else
        if constexpr (INPUT_TYPE::TAG == InputTypeTag::A && IMPL::CallBack::CopyA1Ptr) {
            LocalTensor<int8_t> tmpDst = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyA1Ptr)(tmpDst,
                reinterpret_cast<__gm__ void *>(MATMUL_MODULE(MatmulTensorInfo)->template GetGlobalAddr<IS_INTRA_BLOCK>()),
                curRow, curCol, tileHeight, tileWidth, MATMUL_MODULE(MatmulTensorInfo)->GetUserDefineInfo(),
                MATMUL_MODULE(MatmulTensorInfo)->GetSelfDefineData());
        } else if constexpr (INPUT_TYPE::TAG == InputTypeTag::B && IMPL::CallBack::CopyB1Ptr) {
            LocalTensor<int8_t> tmpDst = dst.template ReinterpretCast<int8_t>();
            (IMPL::CallBack::CopyB1Ptr)(tmpDst,
                reinterpret_cast<__gm__ void *>(MATMUL_MODULE(MatmulTensorInfo)->template GetGlobalAddr<IS_INTRA_BLOCK>()),
                curRow, curCol, tileHeight, tileWidth, MATMUL_MODULE(MatmulTensorInfo)->GetUserDefineInfo(),
                MATMUL_MODULE(MatmulTensorInfo)->GetSelfDefineData());
#endif
        } else {
            // Two int4 elements share one byte in ND format, so width-like values shrink by 2.
            constexpr int32_t widthFactor =
                IsSameTypeV<TransT, int4b_t> && INPUT_TYPE::format == CubeFormat::ND ? INT4_TWO : 1;
            if (MATMUL_MODULE(MatmulShapeInfo)->template IsTranspose<IS_INTRA_BLOCK>()) {
                if constexpr (IsCopyFromUB<INPUT_TYPE, MM_CFG>()) {
                    LocalTensor<SrcT> src;
                    src.SetAddr(MATMUL_MODULE(MatmulTensorInfo)->GetLocalAddr());
                    // Transposed: pass (col,row) and (width,height) swapped.
                    CopyTileToCubeFromUB(
                        dst, src, curCol, curRow, tileWidth, tileHeight / widthFactor,
                        MATMUL_MODULE(MatmulShapeInfo)->template GetBaseHeight<true>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetBaseWidth<true>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetOrgHeight<true, IS_INTRA_BLOCK>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetOrgWidth<true, IS_INTRA_BLOCK>() / widthFactor,
                        MATMUL_MODULE(MatmulShapeInfo)->template IsKRowDirec<IS_INTRA_BLOCK>());
                } else {
                    GlobalTensor<SrcT> src;
                    src.SetGlobalBuffer(MATMUL_MODULE(MatmulTensorInfo)->template GetGlobalAddr<IS_INTRA_BLOCK>());
                    CopyTileToCubeFromGM(dst, src, curCol, curRow, tileWidth, tileHeight / widthFactor,
                        MATMUL_MODULE(MatmulShapeInfo)->template GetBaseHeight<true>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetBaseWidth<true>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetOrgHeight<true, IS_INTRA_BLOCK>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetOrgWidth<true, IS_INTRA_BLOCK>() / widthFactor,
                        MATMUL_MODULE(CopyCubeInParams)->template GetStepCol<false>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template IsKRowDirec<IS_INTRA_BLOCK>());
                }
            } else {
                if constexpr (IsCopyFromUB<INPUT_TYPE, MM_CFG>()) {
                    LocalTensor<SrcT> src;
                    src.SetAddr(MATMUL_MODULE(MatmulTensorInfo)->GetLocalAddr());
                    CopyTileToCubeFromUB(
                        dst, src, curRow, curCol, tileHeight, tileWidth / widthFactor,
                        MATMUL_MODULE(MatmulShapeInfo)->template GetBaseHeight<false>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetBaseWidth<false>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetOrgHeight<false, IS_INTRA_BLOCK>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetOrgWidth<false, IS_INTRA_BLOCK>() / widthFactor,
                        MATMUL_MODULE(MatmulShapeInfo)->template IsKRowDirec<IS_INTRA_BLOCK>());
                } else {
                    GlobalTensor<SrcT> src;
                    src.SetGlobalBuffer(MATMUL_MODULE(MatmulTensorInfo)->template GetGlobalAddr<IS_INTRA_BLOCK>());
                    CopyTileToCubeFromGM(
                        dst, src, curRow, curCol, tileHeight, tileWidth / widthFactor,
                        MATMUL_MODULE(MatmulShapeInfo)->template GetBaseHeight<false>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetBaseWidth<false>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetOrgHeight<false, IS_INTRA_BLOCK>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template GetOrgWidth<false, IS_INTRA_BLOCK>() / widthFactor,
                        MATMUL_MODULE(CopyCubeInParams)->template GetStepCol<false>(),
                        MATMUL_MODULE(MatmulShapeInfo)->template IsKRowDirec<IS_INTRA_BLOCK>());
                }
            }
        }
    }

private:
    // Elements per 32B cube block for the source type.
    constexpr static int32_t c0Size_ = AuxGetC0Size<SrcT>();

    // True when the given tile is smaller than the full static tile (a tail tile).
    // NOTE(review): currently unreferenced within this file — presumably kept for
    // other copy-cube-in specializations; confirm before removing.
    __aicore__ inline bool IsTailTile(int32_t tileHeight, int32_t tileWidth)
    {
        if (MATMUL_MODULE(MatmulShapeInfo)->IsTranspose()) {
            return GetStaticTileHeight<true>() != tileHeight || GetStaticTileWidth<true>() != tileWidth;
        } else {
            return GetStaticTileHeight<false>() != tileHeight || GetStaticTileWidth<false>() != tileWidth;
        }
    }

    // Zero-fill the regions of the NZ destination that the tail tile does not cover,
    // so a full static tile can be consumed downstream.
    // Norm/basic-block templates pad only the uncovered bottom and right areas;
    // MDL templates clear the whole static tile in one InitConstValue.
    template <typename DataType>
    __aicore__ inline void StaticPadNd2Nz(const LocalTensor<DataType>& dst, const int32_t staticHeight,
        const int32_t staticWidth, const int32_t tileHeight, const int32_t tileWidth)
    {
        if constexpr (DoMatmulNorm(MM_CFG) || DoMatmulBasicBlock(MM_CFG) || DoMatmulSpecialBasicBlock(MM_CFG)) {
            int32_t tileWidthC0 = Ceil(tileWidth, c0Size_);
            int32_t staticWidthC0 = Ceil(staticWidth, c0Size_);
            // pad left bottom area of src.
            if (tileHeight < staticHeight) {
                InitConstValueParams<DataType> initConstValueParams;
                initConstValueParams.repeatTimes = tileWidthC0;
                initConstValueParams.blockNum = staticHeight - tileHeight;
                initConstValueParams.dstGap = tileHeight;
                initConstValueParams.initValue = 0;
                InitConstValue(dst[tileHeight * c0Size_], initConstValueParams);
            }
            // pad right area of src
            if (tileWidthC0 < staticWidthC0) {
                InitConstValueParams<DataType> initConstValueParams;
                initConstValueParams.repeatTimes = 1;
                initConstValueParams.blockNum = (staticWidthC0 - tileWidthC0) * staticHeight;
                initConstValueParams.dstGap = 0;
                initConstValueParams.initValue = 0;
                InitConstValue(dst[tileWidthC0 * staticHeight * c0Size_], initConstValueParams);
            }
        } else if constexpr (DoMatmulMDL(MM_CFG) || DoMatmulSpecialMDL(MM_CFG)) {
            using params = InitConstValueParams<DataType>;
            InitConstValue(dst,
                params{ 1, static_cast<uint16_t>(staticHeight * staticWidth * sizeof(DataType) / ONE_BLK_SIZE), 0, 0 });
        }
    }

    // Stage one GM tile into L1, dispatching on the compile-time input format.
    // ND int8 with stepCol > 1 is split into (stepCol - 1) full base blocks plus a
    // tail block so each ND2NZ destination stays c0-aligned.
    __aicore__ inline void CopyTileToCubeFromGM(const LocalTensor<TransT>& dst, const GlobalTensor<SrcT>& src,
        int32_t curRow, int32_t curCol, int32_t tileHeight, int32_t tileWidth, int32_t baseHeight, int32_t baseWidth,
        int32_t orgHeight, int32_t orgWidth, int32_t stepCol, bool iskRowDirec)
    {
        if constexpr (INPUT_TYPE::format == CubeFormat::ND) {
            if constexpr (sizeof(TransT) == sizeof(int8_t)) {
                if (tileWidth < baseWidth || baseWidth % c0Size_ == 0 || stepCol == 1) {
                    CopyND2NZ(dst, src, curRow * baseHeight, curCol * baseWidth, tileHeight,
                        tileWidth, orgWidth, 1, 0, 0, iskRowDirec);
                } else {
                    CopyND2NZ(dst, src, curRow * baseHeight,
                        curCol * baseWidth, tileHeight, baseWidth, orgWidth, stepCol - 1, baseWidth,
                        CeilAlign(baseWidth, c0Size_) * CeilAlign(tileHeight, c0Size_), iskRowDirec);
                    CopyND2NZ(dst[(stepCol - 1) * CeilAlign(baseWidth, c0Size_) * CeilAlign(tileHeight, c0Size_)], src,
                        curRow * baseHeight, (curCol + stepCol - 1) * baseWidth, tileHeight,
                        tileWidth - (stepCol - 1) * baseWidth, orgWidth, 1, 0, 0, iskRowDirec);
                }
            } else {
                CopyND2NZ(dst, src, curRow * baseHeight, curCol * baseWidth, tileHeight, tileWidth, orgWidth);
            }
        } else if constexpr (INPUT_TYPE::format == CubeFormat::NZ) {
            CopyNZ2NZ(dst, src,
                curRow * baseHeight, curCol * baseWidth, tileHeight, tileWidth, orgHeight, iskRowDirec);
        } else if constexpr (INPUT_TYPE::format == CubeFormat::VECTOR) {
            CopyVector2A1(dst, src, curCol * baseWidth, Ceil(tileWidth, c0Size_));
        } else if constexpr (INPUT_TYPE::format == CubeFormat::SCALAR) {
            return;
        } else {
            ASCENDC_ASSERT(false,
                { KERNEL_LOG(KERNEL_ERROR, "MatmulApi only support input format ND/NZ/VECTOR/SCALAR."); });
        }
    }

    // Stage one UB tile into L1; only available on 300-series cores.
    // VECTOR format is asserted unsupported here but still issues the copy so CPU
    // debug builds (where ASCENDC_ASSERT may be non-fatal) keep the original behavior.
    __aicore__ inline void CopyTileToCubeFromUB(const LocalTensor<TransT>& dst, const LocalTensor<SrcT>& src,
        int32_t curRow, int32_t curCol, int32_t tileHeight, int32_t tileWidth, int32_t baseHeight, int32_t baseWidth,
        int32_t orgHeight, int32_t orgWidth, bool iskRowDirec)
    {
#if __CCE_AICORE__ != 300
        ASCENDC_ASSERT(false, { KERNEL_LOG(KERNEL_ERROR, "CopyTileToCubeFromUB only support input from UB."); });
#else
        if constexpr (INPUT_TYPE::format == CubeFormat::ND) {
            CopyND2NZ(dst, src, curRow * baseHeight, curCol * baseWidth, tileHeight, tileWidth, orgWidth);
        } else if constexpr (INPUT_TYPE::format == CubeFormat::NZ) {
            CopyNZ2NZ(dst, src, curRow * baseHeight, curCol * baseWidth, tileHeight, tileWidth, orgHeight);
        } else if constexpr (INPUT_TYPE::format == CubeFormat::VECTOR) {
            ASCENDC_ASSERT(false, { KERNEL_LOG(KERNEL_ERROR,
            "When input format is VECTOR, only support A transpose and B untranspose."); });
            CopyVector2A1(dst, src, curCol * baseWidth, Ceil(tileWidth, c0Size_));
        } else if constexpr (INPUT_TYPE::format == CubeFormat::SCALAR) {
            return;
        } else {
            ASCENDC_ASSERT(false,
                { KERNEL_LOG(KERNEL_ERROR, "MatmulApi only support input format ND/NZ/VECTOR/SCALAR."); });
        }
#endif
    }

    // NZ -> NZ copy from GM. Falls back to one DataCopy per c0 column when the
    // source stride would overflow the uint16_t stride field.
    // kAlignToC0Size: for int8, align the destination column height to c0 (32)
    // instead of 16 when K runs along the row direction.
    __aicore__ inline void CopyNZ2NZ(const LocalTensor<TransT>& dst, const GlobalTensor<SrcT>& src,
        const int32_t row, const int32_t col, const int32_t height, const int32_t width, const int32_t gRow,
        const bool kAlignToC0Size = false)
    {
        ASCENDC_ASSERT((gRow >= height), {
            KERNEL_LOG(KERNEL_ERROR,
                "NZ2NZ height larger than origin matrix height, gRow is %d, which should be no less than height %d.",
                gRow, height);
        });
        int32_t alignedGRow = Ceil(gRow, BLOCK_CUBE) * BLOCK_CUBE;
        int64_t srcOffset = (int64_t)row * (int64_t)c0Size_ + (int64_t)col * (int64_t)alignedGRow;
        // height direction need to be 16 aligned
        auto alignHeight = Ceil(height, BLOCK_CUBE) * BLOCK_CUBE;
        int32_t blockLen = alignHeight * c0Size_ * sizeof(TransT) / ONE_BLK_SIZE;
        int32_t srcStride = (alignedGRow - alignHeight) * (c0Size_ * sizeof(TransT) / ONE_BLK_SIZE);
        if constexpr (IsSameTypeV<TransT, int4b_t>) {
            // int4 packs two elements per byte.
            blockLen /= INT4_TWO;
            srcStride /= INT4_TWO;
        }
        if (srcStride >= UINT16_MAX) {
            for (int32_t i = 0; i < Ceil(width, c0Size_); ++i) {
                DataCopy(dst[i * alignHeight * c0Size_], src[srcOffset + i * gRow * c0Size_],
                    { 1, static_cast<uint16_t>(blockLen), 0, 0 });
            }
        } else {
            uint16_t nburst = Ceil(width, c0Size_);
            int32_t dstStride = 0;
            if constexpr (IsSameTypeV<TransT, int8_t>) {
                if (kAlignToC0Size) {
                    auto alignHeightC0Size = Ceil(height, c0Size_) * c0Size_;
                    dstStride = alignHeightC0Size - alignHeight;
                }
            }
            DataCopy(dst, src[srcOffset], { nburst, static_cast<uint16_t>(blockLen), static_cast<uint16_t>(srcStride),
                static_cast<uint16_t>(dstStride) });
        }
    }

    // NZ -> NZ copy from UB; same stride-overflow fallback as the GM overload.
    __aicore__ inline void CopyNZ2NZ(const LocalTensor<TransT>& dst, const LocalTensor<SrcT>& src,
        const int32_t row, const int32_t col, const int32_t height, const int32_t width, const int32_t gRow)
    {
        ASCENDC_ASSERT((gRow >= height),
                    { KERNEL_LOG(KERNEL_ERROR, "gRow is %d, which should be no less than height %d.", gRow, height); });
        int32_t srcOffset = row * c0Size_ + col * gRow;
        // height direction need to be 16 aligned
        auto alignHeight = Ceil(height, BLOCK_CUBE) * BLOCK_CUBE;
        int32_t blockLen = alignHeight * c0Size_ * sizeof(TransT) / ONE_BLK_SIZE;
        int32_t srcStride = (gRow - alignHeight) * (c0Size_ * sizeof(TransT) / ONE_BLK_SIZE);

        if (srcStride >= UINT16_MAX) {
            for (int32_t i = 0; i < width / c0Size_; ++i) {
                DataCopy(dst[i * alignHeight * c0Size_], src[srcOffset + i * gRow * c0Size_],
                    { 1, static_cast<uint16_t>(blockLen), 0, 0 });
            }
        } else {
            DataCopy(dst, src[srcOffset],
                { static_cast<uint16_t>(width / c0Size_), static_cast<uint16_t>(blockLen),
                static_cast<uint16_t>(srcStride), 0 });
        }
    }

    // Copy `blockLen` cube blocks of a VECTOR-format input from GM into A1,
    // starting at column `col`, using block-mode vector copy.
    __aicore__ inline void CopyVector2A1(
        const LocalTensor<TransT>& dst, const GlobalTensor<SrcT>& src, const int32_t col, const int32_t blockLen)
    {
        ASCENDC_ASSERT((col >= 0), { KERNEL_LOG(KERNEL_ERROR, "col is %d, which should be no less than 0.", col); });
        ASCENDC_ASSERT((INPUT_TYPE::format == CubeFormat::VECTOR),
                    { KERNEL_LOG(KERNEL_ERROR, "INPUT_TYPE::format should be CubeFormat::VECTOR."); });

        DataCopyParams dataCopyInfo;
        dataCopyInfo.blockCount = 1;
        dataCopyInfo.blockLen = blockLen;
        dataCopyInfo.srcStride = 0;
        dataCopyInfo.dstStride = 0;
        DataCopyEnhancedParams enhancedParams;
        enhancedParams.blockMode = BlockMode::BLOCK_MODE_VECTOR;
        DataCopy(dst, src[col], dataCopyInfo, enhancedParams);
        return;
    }

    // UB overload of CopyVector2A1 (no enhanced block-mode params needed).
    __aicore__ inline void CopyVector2A1(const LocalTensor<TransT>& dst, const LocalTensor<SrcT>& src,
        const int32_t col, const int32_t blockLen)
    {
        ASCENDC_ASSERT((col >= 0), { KERNEL_LOG(KERNEL_ERROR, "col is %d, which should be no less than 0.", col); });
        ASCENDC_ASSERT((INPUT_TYPE::format == CubeFormat::VECTOR),
                    { KERNEL_LOG(KERNEL_ERROR, "INPUT_TYPE::format should be CubeFormat::VECTOR."); });

        DataCopyParams dataCopyInfo;
        dataCopyInfo.blockCount = 1;
        dataCopyInfo.blockLen = blockLen;
        dataCopyInfo.srcStride = 0;
        dataCopyInfo.dstStride = 0;
        DataCopy(dst, src[col], dataCopyInfo);
        return;
    }

#if __CCE_AICORE__ >= 220
    // ND -> NZ copy from GM via the hardware nd2nz path.
    // gCol is the source row pitch (original matrix width); ndNum/srcNdMatrixStride/
    // dstNzMatrixStride describe multi-matrix batched transfers.
    // With static padding enabled, tail tiles are zero-padded to the static tile first.
    __aicore__ inline void CopyND2NZ(const LocalTensor<TransT>& dst, const GlobalTensor<SrcT>& src,
        const int32_t row, const int32_t col, const int32_t height, const int32_t width, const int32_t gCol,
        const int32_t ndNum = 1, const int32_t srcNdMatrixStride = 0, const int32_t dstNzMatrixStride = 0,
        const bool kAlignToC0Size = false)
    {
        ASCENDC_ASSERT((row >= 0), { KERNEL_LOG(KERNEL_ERROR, "row is %d, which should be no less than 0.", row); });
        ASCENDC_ASSERT((col >= 0), { KERNEL_LOG(KERNEL_ERROR, "col is %d, which should be no less than 0.", col); });
        ASCENDC_ASSERT((height > 0),
            { KERNEL_LOG(KERNEL_ERROR, "height is %d, which should be larger than 0.", height); });
        ASCENDC_ASSERT((width > 0),
            { KERNEL_LOG(KERNEL_ERROR, "width is %d, which should be larger than 0.", width); });
        ASCENDC_ASSERT((gCol >= width), {
            KERNEL_LOG(KERNEL_ERROR,
                "ND2NZ width larger than origin matrix width, gCol is %d, which should be no less than width %d.", gCol,
                width);
        });
        int32_t dstNzC0Stride = 0;
        if constexpr (IsStaticPaddingEnable(MM_CFG)) {
            int32_t tileHeight = GetStaticTileHeight<INPUT_TYPE::isTrans>();
            int32_t tileWidth = GetStaticTileWidth<INPUT_TYPE::isTrans>();
            if (tileHeight != height || tileWidth != width) {
                StaticPadNd2Nz<TransT>(dst, tileHeight, tileWidth, height, width);
                dstNzC0Stride = tileHeight;
            }
        }
        int64_t srcOffset;
        if constexpr (IsSameTypeV<TransT, int4b_t>) {
            // row offset counts elements; int4 stores two elements per byte.
            srcOffset = ((int64_t)row * (int64_t)gCol * INT4_TWO + (int64_t)col);
        } else {
            srcOffset = ((int64_t)row * (int64_t)gCol  + (int64_t)col);
        }
        Nd2NzParams nd2nzParams;
        nd2nzParams.ndNum = ndNum;
        nd2nzParams.nValue = height;
        nd2nzParams.dValue = width;
        nd2nzParams.srcNdMatrixStride = srcNdMatrixStride;
        nd2nzParams.srcDValue = gCol;

        if (dstNzC0Stride) {
            nd2nzParams.dstNzC0Stride = dstNzC0Stride;
        } else {
            // when k is row(height) axis, int8 type gm->l1 nd2nz should be aligned to 32(c0Size)
            // while float/half type should be aligned to 16
            if (kAlignToC0Size) {
                nd2nzParams.dstNzC0Stride = Ceil(height, c0Size_) * c0Size_;
            } else {
                nd2nzParams.dstNzC0Stride = Ceil(height, BLOCK_CUBE) * BLOCK_CUBE;
            }
        }
        nd2nzParams.dstNzNStride = 1;
        nd2nzParams.dstNzMatrixStride = dstNzMatrixStride;
    #if __CCE_AICORE__ == 220
        if constexpr (!ToMatmulConfig(MM_CFG).intrinsicsCheck) {
            DataCopy(dst, src[srcOffset], nd2nzParams);
        } else {
            // gCol would overflow srcDValue's range: copy row by row instead.
            if (gCol >= UINT16_MAX) {
                nd2nzParams.nValue = 1;
                nd2nzParams.srcDValue = width;
                for (int32_t i = 0; i < height; ++i) {
                    DataCopy(dst[i * c0Size_], src[srcOffset + gCol * i], nd2nzParams);
                }
            } else {
                DataCopy(dst, src[srcOffset], nd2nzParams);
            }
        }
    #else
        DataCopy(dst, src[srcOffset], nd2nzParams); // stride scope has increased
    #endif
    }

    // ND -> NZ copy within UB, done with vector-mode DataCopy per cube block.
    // Unaligned or oversized source pitch degrades to one 32B block per transfer.
    // NOTE(review): `tail` and the pad-zero bookkeeping (`calcWidthExr`) are computed
    // but the tail column is not copied here — presumably handled by the caller.
    __aicore__ inline void CopyND2NZ(const LocalTensor<TransT>& dst, const LocalTensor<SrcT>& src,
        const int32_t row, const int32_t col, const int32_t height, const int32_t width, const int32_t gCol)
    {
        ASSERT(gCol >= width && "Copy ND block ub->ub width larger than origin matrix width.");
        int32_t calcWidth = width / c0Size_; // cube block numbers that do not need to be pad zero
        int32_t tail = width % c0Size_;
        int32_t dstOffset = 0;
        int32_t srcOffset = row * gCol + col;
        int32_t calcWidthExr = Ceil(width, c0Size_);
        int32_t calcHeightExr = Ceil(height, BLOCK_CUBE);

        DataCopyEnhancedParams enhancedParams;
        enhancedParams.blockMode = BlockMode::BLOCK_MODE_VECTOR;

        int32_t srcStride = gCol * sizeof(SrcT) / ONE_BLK_SIZE - 1;
        if (gCol % c0Size_ || srcStride >= UINT16_MAX) {
            // each block len is only 32B
            for (int32_t i = 0; i < calcWidth; i++) {
                for (int32_t j = 0; j < height; j++) {
                    DataCopy(dst[dstOffset], src[srcOffset], { 1, 1, 0, 0 }, enhancedParams);
                    dstOffset += c0Size_;
                    srcOffset += gCol;
                }
                srcOffset += c0Size_;
            }
        } else {
            // data copy stride is aligned
            for (int32_t i = 0; i < calcWidth; i++) {
                DataCopy(dst[dstOffset], src[srcOffset],
                    { static_cast<uint16_t>(height), 1, static_cast<uint16_t>(srcStride), 0 }, enhancedParams);
                dstOffset += calcHeightExr * BLOCK_CUBE * c0Size_;
                srcOffset += c0Size_;
            }
        }
    }
#endif
};
}      // namespace matmul
#endif // IMPL_MATMUL_MODULES_STAGE_COPY_CUBE_IN_DATA_COPY_WRAPPER_H
